{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "essential-mailing",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1.],\n",
       "        [0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1.],\n",
       "        [1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1.],\n",
       "        [1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1.],\n",
       "        [1., 1., 1., 1., 0., 1., 1., 0., 1., 0., 1., 1., 1.]]),\n",
       " array([0, 0, 0, 0, 0]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "# Load the data set\n",
    "def load_data(path='数字01.txt'):\n",
    "    \"\"\"Read comma-separated samples from `path` and return (x, y).\n",
    "\n",
    "    Each non-blank line holds 12 feature values followed by one integer\n",
    "    label. x has shape (n_samples, 13): columns 0-11 are the features\n",
    "    and column 12 is a constant 1, so the last weight acts as the bias b.\n",
    "    y has shape (n_samples,) and integer dtype.\n",
    "    \"\"\"\n",
    "    with open(path) as fr:\n",
    "        # Skip blank lines (e.g. a trailing newline at the end of the\n",
    "        # file); they cannot be parsed as a row of 13 values.\n",
    "        rows = [line.strip().split(',') for line in fr if line.strip()]\n",
    "\n",
    "    # The last column stays 1; it is used to learn the bias b\n",
    "    x = np.ones((len(rows), 13), dtype=float)\n",
    "    y = np.empty(len(rows), dtype=int)\n",
    "\n",
    "    for i, row in enumerate(rows):\n",
    "        x[i, :12] = row[:12]\n",
    "        y[i] = row[12]\n",
    "\n",
    "    return x, y\n",
    "\n",
    "\n",
    "# Load once; the cells below read x and y as globals\n",
    "x, y = load_data()\n",
    "# Preview the first five feature rows and labels\n",
    "(x[:5], y[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "communist-courage",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared state used by the cells below\n",
    "# N = number of samples, M = number of columns of x (12 features + bias)\n",
    "N, M = x.shape\n",
    "# One weight per column of x, every weight starting at 1 / M.\n",
    "# Sized from M instead of a hard-coded 13 so it always matches x.\n",
    "w = np.full(M, 1 / M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "quantitative-running",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.6997597020563492, 0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prediction function\n",
    "def predict(x):\n",
    "    \"\"\"Return sigmoid(w . x) for sample x, using the global weights w.\"\"\"\n",
    "    score = np.dot(w, x)\n",
    "    return 1 / (1 + np.exp(-score))\n",
    "\n",
    "\n",
    "# Sanity check: predicted probability next to the true label of sample 0\n",
    "(predict(x=x[0]), y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "identified-phenomenon",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.6997597020563492, 21.39089054300221)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-sample loss\n",
    "def get_loss(x, y):\n",
    "    \"\"\"Return the loss of one sample; it is never negative.\n",
    "\n",
    "    a is the probability from predict(x); the predicted class is 1 when\n",
    "    a > 0.5, otherwise 0. Labels are assumed to be 0 or 1 (TODO confirm\n",
    "    against the data file).\n",
    "    \"\"\"\n",
    "    a = predict(x)\n",
    "    pred = 1 if a > 0.5 else 0\n",
    "\n",
    "    # Correct prediction: pred - y is 0, so the loss is 0.\n",
    "    # Wrong prediction: the loss is the distance |a - y| between the\n",
    "    # probability and the label. A signed (pred - y) * a would be negative\n",
    "    # for a false negative (pred=0, y=1), letting errors cancel each other\n",
    "    # when the per-sample losses are summed.\n",
    "    return abs(pred - y) * abs(a - y)\n",
    "\n",
    "\n",
    "# Total loss over the data set\n",
    "def total_loss():\n",
    "    \"\"\"Add up get_loss over the first N samples of the global x and y.\"\"\"\n",
    "    running = 0\n",
    "    for sample, label in zip(x[:N], y[:N]):\n",
    "        running += get_loss(sample, label)\n",
    "    return running\n",
    "\n",
    "\n",
    "# Loss of the first sample next to the total over all samples\n",
    "(get_loss(x=x[0], y=y[0]), total_loss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "surprising-japanese",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1.])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_gradient(x, y):\n",
    "    \"\"\"Return (pred - y) * x, the per-sample gradient with respect to w.\n",
    "\n",
    "    pred is 1 when predict(x) > 0.5, otherwise 0.\n",
    "    \"\"\"\n",
    "    if predict(x) > 0.5:\n",
    "        pred = 1\n",
    "    else:\n",
    "        pred = 0\n",
    "\n",
    "    # Correct prediction: pred - y is 0, so the gradient is all zeros.\n",
    "    # Wrong prediction, following the derivation used in this notebook:\n",
    "    #   loss = (pred - y) * w * x\n",
    "    #   d_loss/d_w = (pred - y) * x\n",
    "    error = pred - y\n",
    "    return error * x\n",
    "\n",
    "\n",
    "# Gradient contributed by the first sample\n",
    "get_gradient(x=x[0], y=y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "embedded-extraction",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 21.226275985986867\n",
      "20 17.71212594685117\n",
      "40 0.5362320220384172\n",
      "60 0.023755444504269863\n",
      "80 0.5199281949509837\n",
      "100 -0.48343803358109694\n",
      "120 0.5133800297353193\n",
      "140 0.5099304586338251\n",
      "160 -0.4934366296775313\n",
      "180 0.5033807598304969\n"
     ]
    }
   ],
   "source": [
    "# Training: 200 passes over the data, updating w in place after every\n",
    "# sample with a step size of 1e-4\n",
    "for epoch in range(200):\n",
    "    for sample, label in zip(x, y):\n",
    "        grad = get_gradient(sample, label)\n",
    "        w -= 1e-4 * grad\n",
    "\n",
    "    # Report the total loss every 20th epoch\n",
    "    if not epoch % 20:\n",
    "        print(epoch, total_loss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "connected-extra",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.984375"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluation: fraction of the samples in x that are classified correctly\n",
    "# (the same x and y that the training loop iterates over)\n",
    "correct = 0\n",
    "for sample, label in zip(x, y):\n",
    "    guess = int(predict(sample) > 0.5)\n",
    "    if guess == label:\n",
    "        correct += 1\n",
    "\n",
    "correct / N"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
